(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(*
 * Information about functions in the program we are translating,
 * and the call-graph between them.
 *)
signature FUNCTION_INFO =
sig
  (* Collection of per-function information together with cached
   * call-graph data (callees, topological order, recursive groups). *)
  type fn_info;

  (* Information about a single function. The field meanings are documented
   * on the corresponding record type in structure FunctionInfo. *)
  type function_def = {
    name : string,
    args : (string * typ) list,
    return_type : typ,
    const : term,
    raw_const : term,
    definition : thm,
    mono_thm : thm,
    invented_body : bool
  };

  (* Build the function information for the C file with the given name,
   * using the program information obtained via ProgramInfo.get_prog_info. *)
  val init_fn_info : Proof.context -> string -> fn_info

  (* All functions, keyed by function name. *)
  val get_functions : fn_info -> function_def Symtab.table;
  (* Look up a function by name; fails (via Utils.the') if it does not exist. *)
  val get_function_def : fn_info -> string -> function_def;
  (* Arguments of the named function, in order. *)
  val get_function_args : fn_info -> string -> (string * typ) list;
  (* Find the function whose "raw_const" is exactly the given term, if any. *)
  val get_function_from_const : fn_info -> term -> function_def option;

  (* Transform every function; a result of NONE removes the function (e.g.
   * because it was inlined). All cached call-graph data is recomputed. *)
  val map_fn_info : (function_def -> function_def option) -> fn_info -> fn_info;

  (* Names of the functions called directly by the named function. *)
  val get_function_callees : fn_info -> string -> string list;
  (* All functions in topological order of the call graph; mutually
   * recursive functions share the same inner list. *)
  val get_topo_sorted_functions : fn_info -> string list list;

  (* Functional updates of a single "function_def" field. Note that
   * fn_def_update_const also resets "raw_const" to the head of the new
   * constant. *)
  val fn_def_update_name : string -> function_def -> function_def;
  val fn_def_update_args : (string * typ) list -> function_def -> function_def;
  val fn_def_update_const : term -> function_def -> function_def;
  val fn_def_update_definition : thm -> function_def -> function_def;
  val fn_def_update_return_type : typ -> function_def -> function_def;
  val fn_def_update_mono_thm : thm -> function_def -> function_def;
  val fn_def_update_invented_body : bool -> function_def -> function_def;

  (* Is the named function recursive (directly or mutually)? *)
  val is_function_recursive : fn_info -> string -> bool;
  (* The group of mutually recursive functions containing the named function
   * (including itself), or [] if the function is not recursive. *)
  val get_recursive_group : fn_info -> string -> string list;
  (* Names of all recursive functions. *)
  val get_recursive_functions : fn_info -> string list;
end;

structure FunctionInfo : FUNCTION_INFO =
struct

type function_def = {
    (*  Name of the function. *)
    name : string,

    (* Arguments of the function, in order, excluding measure variables. *)
    args : (string * typ) list,

    (* Return type of the function ("unit" is used for void). *)
    return_type : typ,

    (* Constant for the function, which can be inserted as a call to the
     * function. Unlike "raw_const", this includes any locale parameters
     * required by the function. *)
    const: term,

    (* Raw constant for the function. Existence of this constant in another
     * function's body indicates that that function calls this one. *)
    raw_const: term,

    (* Definition of the function. The right-hand side of its conclusion is
     * scanned for other functions' "raw_const"s to build the call graph. *)
    definition : thm,

    (* monad_mono theorem for the function. *)
    mono_thm : thm,

    (* Is this function generated by AutoCorres as a placeholder for
     * a function we didn't have the source code to? *)
    invented_body : bool
};

datatype fn_info = FunctionInfo of {
  (* Database of "function_def" records, keyed by function name. *)
  function_info : function_def Symtab.table,

  (* Functions directly called by a particular function. *)
  function_callees : string list Symtab.table,

  (* Mapping from each function's "raw_const" back to its name. (cache) *)
  const_to_function : string Termtab.table,

  (* Topologically sorted functions, based on the call graph. Mutually
   * recursive functions are grouped in the same inner list. (cache) *)
  topo_sorted_functions : string list list,

  (* For each function, the group of mutually recursive functions it belongs
   * to (including itself), or [] if the function is not recursive. (cache) *)
  recursive_functions : string list Symtab.table
  };

(*
 * Construct a "fn_info" from a dictionary of function names to "function_def"
 * records.
 *
 * We pre-calculate the call graph, a topological ordering of it and the
 * recursive groups here, to avoid having to do it several times later.
 *)
fun mk_function_info fn_info_dict =
let
  (* Map each function's raw constant back to the function's name. *)
  val const_to_function =
    Symtab.dest fn_info_dict
    |> map (fn (fn_name, fn_def) => (#raw_const fn_def, fn_name))
    |> Termtab.make

  (* Get a function's direct callees: the distinct functions whose raw
   * constants appear in the right-hand side of its definition. *)
  fun get_direct_callees function =
  let
    val body =
      #definition function
      |> Thm.concl_of
      |> Utils.rhs_of

    (* Prepend the callee's name when the atom is a known function constant. *)
    fun add_callee t acc =
      case Termtab.lookup const_to_function t of
          SOME callee => callee :: acc
        | NONE => acc
  in
    Term.fold_aterms add_callee body []
    |> distinct (op =)
  end

  (* Callees of every function; the keys are exactly those of fn_info_dict. *)
  val function_callees = Symtab.map (fn _ => get_direct_callees) fn_info_dict

  (*
   * Get a topologically sorted list of functions, based on the call graph.
   *)
  val topo_sorted_functions = let
    (* Get callees and callers of each function, also adding an edge from each
     * function to itself. *)
    val fn_callees =
          function_callees
          |> Symtab.map (fn f => fn callees => f :: callees)
    val fn_callers = flip_symtab fn_callees
  in
    Topo_Sort.topo_sort {
      cmp = String.compare,
      graph = Symtab.lookup fn_callees #> the,
      converse = Symtab.lookup fn_callers #> the
    } (Symtab.keys fn_callees |> sort String.compare)
    |> map (sort String.compare)
  end

  (* Does a function call itself? *)
  fun is_self_recursive f =
    member (op =) (Symtab.lookup function_callees f |> the) f

  (* Entries for one group of the topological sort: if the group is
   * recursive (several members, or a single self-calling function), every
   * member is paired with the whole group; otherwise its single function is
   * paired with []. *)
  fun group_entries group =
    if length group > 1 orelse is_self_recursive (hd group) then
      map (fn f => (f, group)) group
    else
      [(hd group, [])]

  (* Dictionary from each function to its recursive group. "maps"
   * concatenates in linear time (a left fold with "@" would re-copy the
   * growing accumulator at every step). *)
  val recursive_functions =
    topo_sorted_functions
    |> maps group_entries
    |> Symtab.make
in
  FunctionInfo {
    function_info = fn_info_dict,
    function_callees = function_callees,
    const_to_function = const_to_function,
    topo_sorted_functions = topo_sorted_functions,
    recursive_functions = recursive_functions
  }
end


(*
 * Generate a "fn_info" from the C Parser's output for the given file.
 *
 * Functions without a "<name>_body_def" theorem (declared but never defined)
 * get an instance of "undefined_function_body_def" as their definition and
 * are marked with "invented_body = true". The "mono_thm" field is only a
 * placeholder (TrueI) at this point.
 *)
fun init_fn_info ctxt filename =
let
  val thy = Proof_Context.theory_of ctxt
  val prog_info = ProgramInfo.get_prog_info ctxt filename
  val csenv = #csenv prog_info

  (* Get information about a single function. *)
  fun gen_fn_info name (return_ctype, _, carg_list) =
  let
    (* Convert C Parser return type into a HOL return type. *)
    val return_type =
      if return_ctype = Absyn.Void then
        @{typ unit}
      else
        CalculateState.ctype_to_typ (thy, return_ctype);

    (* Convert arguments into a list of (name, HOL type) pairs. *)
    val arg_list = map (fn v => (
        ProgramAnalysis.get_mname v |> MString.dest,
        CalculateState.ctype_to_typ (thy, ProgramAnalysis.get_vi_type v)
        )) carg_list

    (*
     * Get constant, type signature and definition of the function.
     *
     * The definition may not exist if the function is declared "extern", but
     * never defined. In this case, we replace the body of the function with
     * what amounts to a "fail" command. Any C body is a valid refinement of
     * this, allowing our abstraction to succeed.
     *)
    val const = Utils.get_term ctxt (name ^ "_'proc")
    val myvars_typ = #state_type prog_info
    val (definition, invented) =
        (Proof_Context.get_thms ctxt (name ^ "_body_def"), false)
        handle ERROR _ =>
          ([instantiate' [SOME (Thm.ctyp_of ctxt myvars_typ)] []
                             @{thm undefined_function_body_def}], true)
  in
    {
      name = name,
      args = arg_list,
      return_type = return_type,
      const = const,
      raw_const = const,
      definition = hd definition,
      mono_thm = @{thm TrueI}, (* placeholder *)
      invented_body = invented
    }
  end
in
  ProgramAnalysis.get_fninfo csenv
  |> Symtab.dest
  |> map (uncurry gen_fn_info)
  |> map (fn x => (#name x, x))
  |> Symtab.make
  |> mk_function_info
end

(*
 * Misc getters for function information.
 *)

(* All functions, keyed by function name. *)
fun get_functions (FunctionInfo x) = (#function_info x);

(* Look up a function by name; fails (via Utils.the') if it does not exist. *)
fun get_function_def fn_info name =
    Symtab.lookup (get_functions fn_info) name
    |> Utils.the' ("Function " ^ quote name ^ " does not exist")

(* Arguments of the named function, in order. *)
fun get_function_args fn_info name =
    get_function_def fn_info name |> #args

(* Find the function whose "raw_const" is exactly the given term, if any. *)
fun get_function_from_const (fn_info as FunctionInfo x) term =
    Termtab.lookup (#const_to_function x) term
    |> Option.map (get_function_def fn_info)

(*
 * Map function information for each function in the given program.
 *
 * A return result of "NONE" indicates that the function no longer exists (for
 * example, it has been inlined). All cached call-graph data is recomputed.
 *)
fun map_fn_info f (FunctionInfo x) =
  Symtab.dest (#function_info x)
  |> map snd
  |> List.mapPartial f
  |> map (fn fn_def => (#name fn_def, fn_def))
  |> Symtab.make
  |> mk_function_info

(* Get a list of functions called directly by the given function. *)
fun get_function_callees (FunctionInfo functions) fn_name =
  Symtab.lookup (#function_callees functions) fn_name
  |> Utils.the' ("Function " ^ quote fn_name ^ " does not exist.")

(* Get the group of mutually recursive functions containing the given
 * function (including itself), or [] if it is not recursive. *)
fun get_recursive_group (FunctionInfo functions) fn_name =
  Symtab.lookup (#recursive_functions functions) fn_name
  |> Utils.the' ("Function " ^ quote fn_name ^ " does not exist.")

(* Is the given function recursive (directly or mutually)? *)
fun is_function_recursive fn_info fn_name =
  not (null (get_recursive_group fn_info fn_name))

(* Get the names of all recursive functions. *)
fun get_recursive_functions (FunctionInfo functions) =
  Symtab.dest (#recursive_functions functions)
  |> filter (not o null o snd)
  |> map fst

(* Get a list of functions sorted in topological order. Mutually recursive
 * functions remain in the same list. *)
fun get_topo_sorted_functions (FunctionInfo functions) =
  #topo_sorted_functions functions

(* SML has no functional record update syntax, so each updater rebuilds the
 * whole record. Updating "const" also resets "raw_const" to its head. *)
fun fn_def_update_name new_val old
  = { name = new_val,   args = #args old, return_type = #return_type old, const = #const old, raw_const = #raw_const old,  definition = #definition old, mono_thm = #mono_thm old, invented_body = #invented_body old }
fun fn_def_update_args new_val old
  = { name = #name old, args = new_val,   return_type = #return_type old, const = #const old, raw_const = #raw_const old,  definition = #definition old, mono_thm = #mono_thm old, invented_body = #invented_body old }
fun fn_def_update_const new_val old
  = { name = #name old, args = #args old, return_type = #return_type old, const = new_val,    raw_const = head_of new_val, definition = #definition old, mono_thm = #mono_thm old, invented_body = #invented_body old }
fun fn_def_update_definition new_val old
  = { name = #name old, args = #args old, return_type = #return_type old, const = #const old, raw_const = #raw_const old,  definition = new_val        , mono_thm = #mono_thm old, invented_body = #invented_body old }
fun fn_def_update_return_type new_val old
  = { name = #name old, args = #args old, return_type = new_val         , const = #const old, raw_const = #raw_const old,  definition = #definition old, mono_thm = #mono_thm old, invented_body = #invented_body old }
fun fn_def_update_mono_thm new_val old
  = { name = #name old, args = #args old, return_type = #return_type old, const = #const old, raw_const = #raw_const old,  definition = #definition old, mono_thm = new_val      , invented_body = #invented_body old}
fun fn_def_update_invented_body new_val old
  = { name = #name old, args = #args old, return_type = #return_type old, const = #const old, raw_const = #raw_const old,  definition = #definition old, mono_thm = #mono_thm old, invented_body = new_val}

end

(*
 * Theory data used to save function information into the theory.
 * NOTE(review): the table key is presumably the C file name (cf. the
 * "filename" argument of FunctionInfo.init_fn_info) — confirm at the sites
 * that store into this table.
 *)
structure AutoCorresFunctionInfo = Theory_Data
(
  type T = FunctionInfo.fn_info Symtab.table
  val empty : T = Symtab.empty
  val extend = I
  (* fn_info values contain theorems and cannot be compared, so every
   * conflict is treated as "equal": Symtab.merge then keeps the entry that
   * is already present in the first (left) table. *)
  fun merge (tab1, tab2) = Symtab.merge (fn _ => true) (tab1, tab2)
)
